//
//  InterpGradTest.cpp
//  MNNTests
//
//  Created by MNN on 2022/08/18.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <cmath>
#include <cstring>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include "../tools/train/source/grad/OpGrad.hpp"

using namespace MNN;
using namespace MNN::Express;

/**
 * Regression test for the Interp op and the gradient registered for it in OpGrad.
 *
 * A fixed 2x3x2x3 NCHW input is resized by a factor of 2.5 in both spatial
 * dimensions (to 5x7, HxW) for several (mode, alignCorners) combinations. For each
 * combination the forward result and the gradient w.r.t. the input are compared
 * element-wise, with an absolute tolerance of 1e-4, against hard-coded reference
 * values (presumably exported from PyTorch -- see the TODO notes below).
 */
class InterpGradTest : public MNNTestCase {
public:
    char name[20] = "Interp";
    virtual ~InterpGradTest() = default;

    /**
     * Runs every interpolation case.
     *
     * @param precision unused by this test.
     * @return true when all strict comparisons pass; false on the first strict
     *         mismatch or when a required intermediate result is missing.
     */
    virtual bool run(int precision) {
        const std::vector<int> shape = {2, 3, 2, 3};
        const int inputSize = shape[0] * shape[1] * shape[2] * shape[3];
        auto input = _Input(shape, NCHW);
        const float inputData[] = {
            0.5500, 0.6721, 0.4343, 0.8518, 0.9456, 0.6444, 0.5927, 0.4439, 0.9329,
            0.1434, 0.6933, 0.0180, 0.3173, 0.2903, 0.4159, 0.8706, 0.1812, 0.5890,
            0.3834, 0.0335, 0.9997, 0.7504, 0.5379, 0.9836, 0.3202, 0.4824, 0.9982,
            0.8029, 0.2889, 0.8386, 0.2282, 0.6912, 0.2678, 0.9031, 0.7055, 0.9389};
        auto inputPtr = input->writeMap<float>();
        if (nullptr == inputPtr) {
            MNN_ERROR("%s test failed: cannot map the input for writing!\n", name);
            return false;
        }
        ::memcpy(inputPtr, inputData, inputSize * sizeof(float));

        const float wScale = 2.5;
        const float hScale = 2.5;
        const int outputW  = int(floor(wScale * 3));
        const int outputH  = int(floor(hScale * 2));
        const int outputSize = shape[0] * shape[1] * outputH * outputW;

        float scales[] = {1.0, 1.0, hScale, wScale};
        auto scaleVar  = _Const(scales, {4}, NCHW);

        // Upstream gradient fed into onGrad for every case; laid out as {N, C, outputH, outputW}.
        float outputDiff[] = {
            0.4334, 0.0341, 0.2641, 0.7900, 0.5033, 0.5753, 0.0180, 0.7617, 0.3299,
            0.9388, 0.7645, 0.9430, 0.7312, 0.1058, 0.4154, 0.7537, 0.3959, 0.2058,
            0.9084, 0.9178, 0.6203, 0.3946, 0.5286, 0.0053, 0.5959, 0.4805, 0.3953,
            0.8146, 0.8543, 0.1426, 0.4022, 0.8199, 0.7822, 0.3160, 0.5057, 0.8435,
            0.5449, 0.0964, 0.4719, 0.3557, 0.1786, 0.8186, 0.0859, 0.2833, 0.5462,
            0.1870, 0.9203, 0.1523, 0.3556, 0.9206, 0.2185, 0.5502, 0.8321, 0.5941,
            0.3160, 0.0663, 0.8522, 0.8215, 0.3595, 0.2714, 0.6255, 0.9103, 0.2248,
            0.4765, 0.2330, 0.4213, 0.3474, 0.7129, 0.1307, 0.2414, 0.7421, 0.1453,
            0.5165, 0.7668, 0.1646, 0.0379, 0.1988, 0.1783, 0.3200, 0.5802, 0.7501,
            0.5057, 0.9157, 0.2080, 0.9982, 0.6694, 0.3964, 0.3710, 0.9381, 0.9157,
            0.2548, 0.2127, 0.8212, 0.5140, 0.3528, 0.2028, 0.9128, 0.3492, 0.9882,
            0.0330, 0.4107, 0.5150, 0.8750, 0.1118, 0.1271, 0.5068, 0.4232, 0.0709,
            0.2711, 0.0418, 0.5329, 0.8123, 0.0119, 0.1818, 0.2264, 0.6342, 0.1863,
            0.8303, 0.2253, 0.8572, 0.0520, 0.0255, 0.4094, 0.3164, 0.2758, 0.5764,
            0.7998, 0.7261, 0.9420, 0.8043, 0.7131, 0.3567, 0.1961, 0.3868, 0.4668,
            0.8830, 0.8475, 0.2829, 0.0681, 0.3372, 0.9303, 0.0397, 0.0962, 0.3651,
            0.1226, 0.7876, 0.4374, 0.1730, 0.2058, 0.7499, 0.8105, 0.5794, 0.7401,
            0.3478, 0.8476, 0.8795, 0.7856, 0.6042, 0.4180, 0.4664, 0.8128, 0.6839,
            0.9811, 0.7328, 0.9305, 0.1411, 0.4011, 0.4810, 0.5414, 0.6038, 0.1644,
            0.1686, 0.2125, 0.1554, 0.8285, 0.6496, 0.0667, 0.7326, 0.9510, 0.1087,
            0.4501, 0.8744, 0.3976, 0.8691, 0.7303, 0.2784, 0.4464, 0.8928, 0.6532,
            0.4175, 0.5971, 0.7475, 0.1091, 0.3149, 0.3717, 0.5579, 0.5649, 0.6624,
            0.8024, 0.1316, 0.8202, 0.7971, 0.6213, 0.9040, 0.9452, 0.9925, 0.4661,
            0.7995, 0.0764, 0.0370};

        // Runs one (mode, alignCorners) case: forward pass, forward comparison, then the
        // gradient of the Interp op and its comparison. `tag` is the case label embedded in
        // the log lines. A non-strict comparison only logs mismatches and never fails.
        auto runCase = [&](int mode, bool alignCorners, const char* tag,
                           const std::vector<float>& forwardRef, bool forwardStrict,
                           const std::vector<float>& gradRef, bool gradStrict) -> bool {
            auto output    = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
            auto outputPtr = output->readMap<float>();
            // Forward mismatches are numbered by a running counter.
            if (!compareWithReference(outputPtr, forwardRef, outputSize, tag, "output", true, forwardStrict)) {
                return false;
            }

            auto opExpr = output->expr().first;
            auto grad   = OpGrad::get(opExpr->get()->type());
            if (nullptr == grad) {
                MNN_ERROR("%s %s grad test failed: no gradient registered for the op!\n", name, tag);
                return false;
            }
            auto inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
            if (inputGrad.empty() || nullptr == inputGrad[0].get()) {
                MNN_ERROR("%s %s grad test failed: no gradient produced for the input!\n", name, tag);
                return false;
            }
            auto gotGrad = inputGrad[0]->readMap<float>();
            // Gradient mismatches are numbered by element index.
            return compareWithReference(gotGrad, gradRef, inputSize, tag, "grad", false, gradStrict);
        };

        // Mode values: 1: near, 2: bilinear, 3: cubic, 4: nearest_round.

        // ---- mode 1 (near) ----
        const std::vector<float> nearForward = {
            0.5500, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500,
            0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.5500, 0.6721,
            0.6721, 0.4343, 0.4343, 0.8518, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444,
            0.6444, 0.8518, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.5927,
            0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.5927,
            0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.5927, 0.4439, 0.4439,
            0.9329, 0.9329, 0.1434, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
            0.1434, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.3173, 0.3173,
            0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.3173, 0.2903,
            0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159,
            0.4159, 0.8706, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706,
            0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.3834, 0.3834, 0.3834,
            0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.3834, 0.0335, 0.0335,
            0.9997, 0.9997, 0.3834, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
            0.7504, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504,
            0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.3202, 0.3202, 0.3202, 0.4824,
            0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982,
            0.9982, 0.3202, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.8029,
            0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.8029,
            0.2889, 0.2889, 0.8386, 0.8386, 0.2282, 0.2282, 0.2282, 0.6912, 0.6912,
            0.2678, 0.2678, 0.2282, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
            0.2282, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.9031, 0.9031,
            0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.9031, 0.7055,
            0.7055, 0.9389, 0.9389};
        const std::vector<float> nearGrad = {
            4.3270, 4.1150, 2.9684, 2.3276, 2.6785, 2.0316, 4.0895, 3.3611, 1.8874,
            3.1640, 1.9572, 1.5072, 4.5464, 3.4963, 2.5309, 2.9798, 1.9456, 1.5009,
            2.3557, 1.8592, 3.2530, 4.2045, 2.6478, 0.9581, 4.7076, 2.8998, 3.5921,
            3.7074, 1.4527, 1.8660, 5.2080, 2.2085, 3.8001, 4.8714, 2.2174, 1.5318};
        if (!runCase(1, false, "type 1", nearForward, true, nearGrad, true)) {
            return false;
        }

        // ---- mode 4 (nearest_round) ----
        // TODO: inference of this mode is not aligned with pytorch
        // (forward mismatches are therefore only logged; the gradient is still strict).
        const std::vector<float> nearestRoundForward = {
            0.5500, 0.5500, 0.6721, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500,
            0.6721, 0.6721, 0.6721, 0.4343, 0.4343, 0.8518, 0.8518, 0.9456, 0.9456,
            0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.9456, 0.6444,
            0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.9456, 0.6444, 0.6444, 0.5927,
            0.5927, 0.4439, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
            0.4439, 0.4439, 0.9329, 0.9329, 0.1434, 0.1434, 0.6933, 0.6933, 0.6933,
            0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.6933, 0.0180, 0.0180,
            0.1434, 0.1434, 0.6933, 0.6933, 0.6933, 0.0180, 0.0180, 0.3173, 0.3173,
            0.2903, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903,
            0.2903, 0.4159, 0.4159, 0.8706, 0.8706, 0.1812, 0.1812, 0.1812, 0.5890,
            0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706,
            0.8706, 0.1812, 0.1812, 0.1812, 0.5890, 0.5890, 0.3834, 0.3834, 0.0335,
            0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.0335,
            0.9997, 0.9997, 0.7504, 0.7504, 0.5379, 0.5379, 0.5379, 0.9836, 0.9836,
            0.7504, 0.7504, 0.5379, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504,
            0.5379, 0.5379, 0.5379, 0.9836, 0.9836, 0.3202, 0.3202, 0.4824, 0.4824,
            0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.4824, 0.9982,
            0.9982, 0.8029, 0.8029, 0.2889, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029,
            0.8029, 0.2889, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
            0.2889, 0.2889, 0.8386, 0.8386, 0.2282, 0.2282, 0.6912, 0.6912, 0.6912,
            0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.6912, 0.2678, 0.2678,
            0.9031, 0.9031, 0.7055, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031,
            0.7055, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055,
            0.7055, 0.9389, 0.9389};
        const std::vector<float> nearestRoundGrad = {
            1.5591, 4.2037, 1.4303, 3.0892, 4.5961, 3.5697, 1.7576, 2.5775, 1.5051,
            3.5223, 4.7144, 1.8895, 1.3857, 3.2839, 1.3604, 3.7227, 4.5758, 2.6714,
            1.1237, 1.4307, 2.4008, 3.2887, 5.2241, 1.8103, 1.3488, 2.7237, 2.3129,
            4.5373, 4.1577, 3.1452, 1.9830, 3.2474, 2.8705, 4.0911, 5.1838, 2.4614};
        if (!runCase(4, false, "type 4", nearestRoundForward, false, nearestRoundGrad, true)) {
            return false;
        }

        // ---- mode 2 (bilinear), alignCorners = false ----
        const std::vector<float> bilinearForward = {
            0.5500, 0.5622, 0.6111, 0.6599, 0.6008, 0.5056, 0.4343, 0.5802, 0.5921,
            0.6398, 0.6875, 0.6262, 0.5286, 0.4553, 0.7009, 0.7117, 0.7549, 0.7981,
            0.7280, 0.6202, 0.5394, 0.8216, 0.8313, 0.8699, 0.9086, 0.8298, 0.7118,
            0.6234, 0.8518, 0.8612, 0.8987, 0.9362, 0.8552, 0.7348, 0.6444, 0.5927,
            0.5778, 0.5183, 0.4588, 0.5906, 0.7862, 0.9329, 0.5478, 0.5399, 0.5083,
            0.4767, 0.5806, 0.7296, 0.8414, 0.3681, 0.3881, 0.4683, 0.5485, 0.5407,
            0.5034, 0.4755, 0.1883, 0.2363, 0.4283, 0.6204, 0.5007, 0.2772, 0.1095,
            0.1434, 0.1984, 0.4184, 0.6383, 0.4907, 0.2206, 0.0180, 0.3173, 0.3146,
            0.3038, 0.2930, 0.3280, 0.3782, 0.4159, 0.3726, 0.3633, 0.3260, 0.2887,
            0.3255, 0.3871, 0.4332, 0.5939, 0.5581, 0.4148, 0.2716, 0.3158, 0.4224,
            0.5024, 0.8153, 0.7530, 0.5037, 0.2544, 0.3060, 0.4578, 0.5717, 0.8706,
            0.8017, 0.5259, 0.2501, 0.3035, 0.4667, 0.5890, 0.3834, 0.3484, 0.2084,
            0.0685, 0.3234, 0.7098, 0.9997, 0.4201, 0.3865, 0.2520, 0.1176, 0.3582,
            0.7238, 0.9981, 0.5669, 0.5388, 0.4263, 0.3138, 0.4975, 0.7799, 0.9916,
            0.7137, 0.6911, 0.6006, 0.5101, 0.6368, 0.8359, 0.9852, 0.7504, 0.7291,
            0.6442, 0.5591, 0.6716, 0.8499, 0.9836, 0.3202, 0.3364, 0.4013, 0.4662,
            0.6371, 0.8435, 0.9982, 0.3685, 0.3779, 0.4158, 0.4536, 0.6188, 0.8265,
            0.9822, 0.5616, 0.5440, 0.4736, 0.4032, 0.5455, 0.7586, 0.9184, 0.7546,
            0.7100, 0.5314, 0.3529, 0.4721, 0.6907, 0.8546, 0.8029, 0.7515, 0.5459,
            0.3403, 0.4538, 0.6737, 0.8386, 0.2282, 0.2745, 0.4597, 0.6449, 0.5642,
            0.3948, 0.2678, 0.2957, 0.3354, 0.4942, 0.6529, 0.5853, 0.4422, 0.3349,
            0.5656, 0.5789, 0.6320, 0.6851, 0.6698, 0.6319, 0.6033, 0.8356, 0.8225,
            0.7698, 0.7172, 0.7544, 0.8215, 0.8718, 0.9031, 0.8833, 0.8043, 0.7253,
            0.7755, 0.8689, 0.9389};
        const std::vector<float> bilinearGrad = {
            2.8685, 4.0238, 2.2734, 2.9216, 3.4296, 2.9312, 2.9163, 2.7736, 2.0526,
            3.3512, 3.0745, 1.7981, 2.9985, 3.4252, 1.8688, 3.3955, 3.2112, 2.1006,
            1.9465, 2.0171, 2.4722, 3.6515, 3.6092, 1.5819, 3.0323, 3.1607, 2.6740,
            4.1352, 2.7624, 2.4610, 3.4138, 3.1663, 3.0918, 4.5425, 3.4122, 2.2106};
        if (!runCase(2, false, "type 2 alignCorners false", bilinearForward, true, bilinearGrad, true)) {
            return false;
        }

        // ---- mode 2 (bilinear), alignCorners = true ----
        const std::vector<float> bilinearAlignedForward = {
            0.5500, 0.5907, 0.6314, 0.6721, 0.5928, 0.5136, 0.4343, 0.6255, 0.6638,
            0.7021, 0.7405, 0.6559, 0.5714, 0.4868, 0.7009, 0.7369, 0.7729, 0.8088,
            0.7190, 0.6292, 0.5394, 0.7764, 0.8100, 0.8436, 0.8772, 0.7821, 0.6870,
            0.5919, 0.8518, 0.8831, 0.9143, 0.9456, 0.8452, 0.7448, 0.6444, 0.5927,
            0.5431, 0.4935, 0.4439, 0.6069, 0.7699, 0.9329, 0.4804, 0.4890, 0.4976,
            0.5063, 0.5722, 0.6382, 0.7042, 0.3681, 0.4349, 0.5017, 0.5686, 0.5375,
            0.5065, 0.4755, 0.2557, 0.3808, 0.5059, 0.6309, 0.5029, 0.3748, 0.2467,
            0.1434, 0.3267, 0.5100, 0.6933, 0.4682, 0.2431, 0.0180, 0.3173, 0.3083,
            0.2993, 0.2903, 0.3322, 0.3740, 0.4159, 0.4556, 0.3914, 0.3272, 0.2630,
            0.3284, 0.3938, 0.4592, 0.5939, 0.4745, 0.3551, 0.2358, 0.3246, 0.4136,
            0.5024, 0.7323, 0.5577, 0.3831, 0.2085, 0.3209, 0.4333, 0.5457, 0.8706,
            0.6408, 0.4110, 0.1812, 0.3171, 0.4531, 0.5890, 0.3834, 0.2668, 0.1501,
            0.0335, 0.3556, 0.6776, 0.9997, 0.4751, 0.3700, 0.2648, 0.1596, 0.4383,
            0.7170, 0.9957, 0.5669, 0.4732, 0.3794, 0.2857, 0.5210, 0.7563, 0.9916,
            0.6586, 0.5764, 0.4941, 0.4118, 0.6037, 0.7957, 0.9876, 0.7504, 0.6796,
            0.6087, 0.5379, 0.6865, 0.8350, 0.9836, 0.3202, 0.3743, 0.4283, 0.4824,
            0.6543, 0.8263, 0.9982, 0.4409, 0.4386, 0.4363, 0.4340, 0.6088, 0.7835,
            0.9583, 0.5616, 0.5029, 0.4443, 0.3856, 0.5632, 0.7408, 0.9184, 0.6822,
            0.5672, 0.4523, 0.3373, 0.5177, 0.6981, 0.8785, 0.8029, 0.6316, 0.4602,
            0.2889, 0.4721, 0.6554, 0.8386, 0.2282, 0.3825, 0.5369, 0.6912, 0.5501,
            0.4089, 0.2678, 0.3969, 0.4962, 0.5955, 0.6948, 0.6084, 0.5220, 0.4356,
            0.5656, 0.6099, 0.6541, 0.6984, 0.6667, 0.6350, 0.6033, 0.7344, 0.7236,
            0.7127, 0.7019, 0.7250, 0.7481, 0.7711, 0.9031, 0.8372, 0.7714, 0.7055,
            0.7833, 0.8611, 0.9389};
        const std::vector<float> bilinearAlignedGrad = {
            2.2272, 4.4075, 2.3271, 2.4936, 4.0926, 2.9002, 2.5863, 3.2598, 2.1267,
            2.6511, 3.5678, 1.7748, 2.4711, 3.9431, 1.8645, 2.7803, 3.8429, 2.0980,
            1.8186, 2.5404, 2.4130, 2.6931, 4.1895, 1.6237, 2.6524, 3.7168, 2.6095,
            3.1733, 3.5839, 2.4896, 2.7044, 3.9169, 3.0705, 3.6579, 4.2589, 2.2286};
        if (!runCase(2, true, "type 2 alignCorners true", bilinearAlignedForward, true, bilinearAlignedGrad, true)) {
            return false;
        }

        // ---- mode 3 (cubic), alignCorners = true ----
        const std::vector<float> cubicAlignedForward = {
            0.5500, 0.6016, 0.6601, 0.6721, 0.6108, 0.5159, 0.4343, 0.6184, 0.6688,
            0.7257, 0.7341, 0.6675, 0.5677, 0.4819, 0.7009, 0.7499, 0.8048, 0.8088,
            0.7360, 0.6302, 0.5393, 0.7834, 0.8309, 0.8840, 0.8836, 0.8045, 0.6927,
            0.5968, 0.8518, 0.8981, 0.9495, 0.9456, 0.8612, 0.7444, 0.6444, 0.5927,
            0.5187, 0.4364, 0.4439, 0.5813, 0.7707, 0.9329, 0.4909, 0.4814, 0.4724,
            0.5004, 0.5724, 0.6552, 0.7256, 0.3681, 0.4364, 0.5158, 0.5686, 0.5616,
            0.5159, 0.4754, 0.2452, 0.3913, 0.5592, 0.6368, 0.5508, 0.3766, 0.2253,
            0.1434, 0.3540, 0.5952, 0.6933, 0.5418, 0.2611, 0.0180, 0.3173, 0.3018,
            0.2848, 0.2903, 0.3268, 0.3749, 0.4159, 0.4427, 0.3764, 0.3003, 0.2656,
            0.3056, 0.3856, 0.4551, 0.5940, 0.4664, 0.3189, 0.2357, 0.2799, 0.3986,
            0.5024, 0.7452, 0.5564, 0.3375, 0.2059, 0.2542, 0.4116, 0.5498, 0.8706,
            0.6309, 0.3529, 0.1812, 0.2330, 0.4223, 0.5890, 0.3834, 0.2196, 0.0363,
            0.0335, 0.2988, 0.6761, 0.9997, 0.4665, 0.3191, 0.1539, 0.1478, 0.3794,
            0.7113, 0.9961, 0.5669, 0.4392, 0.2958, 0.2857, 0.4767, 0.7538, 0.9917,
            0.6673, 0.5592, 0.4377, 0.4236, 0.5740, 0.7963, 0.9872, 0.7504, 0.6587,
            0.5553, 0.5379, 0.6546, 0.8315, 0.9836, 0.3202, 0.3426, 0.3740, 0.4824,
            0.6628, 0.8448, 0.9982, 0.4296, 0.4033, 0.3776, 0.4386, 0.6044, 0.7977,
            0.9620, 0.5615, 0.4766, 0.3818, 0.3856, 0.5338, 0.7409, 0.9184, 0.6935,
            0.5498, 0.3861, 0.3327, 0.4633, 0.6841, 0.8748, 0.8029, 0.6105, 0.3896,
            0.2889, 0.4048, 0.6370, 0.8386, 0.2282, 0.3975, 0.5925, 0.6912, 0.6094,
            0.4268, 0.2678, 0.3811, 0.4950, 0.6263, 0.6944, 0.6428, 0.5237, 0.4198,
            0.5656, 0.6127, 0.6671, 0.6984, 0.6832, 0.6406, 0.6033, 0.7502, 0.7304,
            0.7080, 0.7023, 0.7236, 0.7576, 0.7869, 0.9031, 0.8279, 0.7418, 0.7055,
            0.7570, 0.8544, 0.9389};
        const std::vector<float> cubicAlignedGrad = {
            1.9391, 4.9129, 2.1415, 2.2126, 4.4911, 2.7508, 2.3607, 3.6276, 1.9484,
            2.4287, 4.0120, 1.5888, 2.2588, 4.3451, 1.6770, 2.5254, 4.2834, 1.9103,
            1.6577, 2.7228, 2.3390, 2.5627, 4.6302, 1.3659, 2.4378, 4.0773, 2.4462,
            3.0859, 3.9371, 2.2413, 2.5186, 4.3590, 2.8111, 3.4074, 4.8708, 1.8704};
        if (!runCase(3, true, "type 3 alignCorners true", cubicAlignedForward, true, cubicAlignedGrad, true)) {
            return false;
        }

        // ---- mode 3 (cubic), alignCorners = false ----
        // TODO: inference of these arguments combination is wrong
        // (both the forward and the gradient mismatches are therefore only logged).
        const std::vector<float> cubicForward = {
            0.5029,  0.5286,  0.6010,  0.6457,  0.5914,  0.4814,  0.3971,  0.5615,
            0.5862,  0.6574,  0.6995,  0.6399,  0.5250,  0.4368,  0.6890,  0.7116,
            0.7801,  0.8164,  0.7456,  0.6196,  0.5230,  0.8165,  0.8369,  0.9029,
            0.9333,  0.8512,  0.7143,  0.6092,  0.8751,  0.8946,  0.9593,  0.9870,
            0.8998,  0.7578,  0.6488,  0.6671,  0.6195,  0.4714,  0.3974,  0.5638,
            0.8509,  1.0713,  0.5659,  0.5457,  0.4732,  0.4479,  0.5641,  0.7438,
            0.8818,  0.3459,  0.3851,  0.4771,  0.5578,  0.5647,  0.5109,  0.4698,
            0.1259,  0.2246,  0.4809,  0.6677,  0.5654,  0.2781,  0.0578,  0.0247,
            0.1507,  0.4827,  0.7182,  0.5657,  0.1710, -0.1317,  0.2512,  0.2594,
            0.2705,  0.2928,  0.3338,  0.3726,  0.4026,  0.3716,  0.3550,  0.3081,
            0.2790,  0.3139,  0.3848,  0.4391,  0.6334,  0.5628,  0.3898,  0.2489,
            0.2707,  0.4111,  0.5187,  0.8952,  0.7706,  0.4716,  0.2189,  0.2274,
            0.4375,  0.5982,  1.0157,  0.8661,  0.5092,  0.2050,  0.2076,  0.4496,
            0.6347,  0.3832,  0.3061,  0.0645, -0.0544,  0.2232,  0.6986,  1.0637,
            0.4508,  0.3795,  0.1576,  0.0465,  0.2952,  0.7247,  1.0545,  0.5979,
            0.5391,  0.3601,  0.2659,  0.4517,  0.7814,  1.0345,  0.7450,  0.6987,
            0.5626,  0.4852,  0.6081,  0.8381,  1.0146,  0.8126,  0.7721,  0.6558,
            0.5861,  0.6801,  0.8642,  1.0054,  0.2409,  0.2829,  0.3374,  0.4532,
            0.6727,  0.8841,  1.0469,  0.3480,  0.3650,  0.3645,  0.4263,  0.6230,
            0.8455,  1.0166,  0.5809,  0.5435,  0.4237,  0.3677,  0.5149,  0.7615,
            0.9508,  0.8139,  0.7220,  0.4828,  0.3091,  0.4068,  0.6774,  0.8849,
            0.9210,  0.8041,  0.5100,  0.2822,  0.3571,  0.6388,  0.8546,  0.0947,
            0.2011,  0.4682,  0.6758,  0.6104,  0.3575,  0.1637,  0.2385,  0.3196,
            0.5226,  0.6813,  0.6343,  0.4452,  0.3004,  0.5510,  0.5772,  0.6409,
            0.6932,  0.6865,  0.6361,  0.5976,  0.8636,  0.8348,  0.7592,  0.7052,
            0.7386,  0.8270,  0.8948,  1.0073,  0.9533,  0.8136,  0.7107,  0.7626,
            0.9148,  1.0315};
        const std::vector<float> cubicGrad = {
            2.6602, 4.3750, 2.0222, 2.8827, 3.6077, 2.9003, 2.9207, 2.8721, 2.0044,
            3.2637, 3.3148, 1.5907, 2.9417, 3.6414, 1.6572, 3.3495, 3.4318, 1.9783,
            1.8646, 1.8953, 2.5679, 3.7284, 3.9885, 1.2336, 2.9781, 3.3275, 2.5594,
            4.2535, 2.7625, 2.3446, 3.2755, 3.2857, 3.0657, 4.6292, 3.6900, 1.8911};
        if (!runCase(3, false, "type 3 alignCorners false", cubicForward, false, cubicGrad, false)) {
            return false;
        }

        return true;
    }

private:
    /**
     * Compares `size` floats of `got` against `expected` with an absolute tolerance of 1e-4
     * and logs every mismatch through MNN_ERROR.
     *
     * @param got              values to verify; a null pointer is reported as a failure.
     * @param expected         reference values; must hold exactly `size` entries.
     * @param size             number of elements to compare.
     * @param tag              case label embedded in the log line (e.g. "type 1").
     * @param what             kind of result embedded in the log line ("output" or "grad").
     * @param numberByMismatch true: prefix the log line with the 1-based running mismatch
     *                         count; false: prefix it with the element index.
     * @param failOnMismatch   true: return false at the first mismatch; false: log only.
     * @return false on a null buffer, a reference size mismatch, or (when failOnMismatch)
     *         a value mismatch; true otherwise.
     */
    bool compareWithReference(const float* got, const std::vector<float>& expected, int size, const char* tag,
                              const char* what, bool numberByMismatch, bool failOnMismatch) const {
        if (nullptr == got) {
            MNN_ERROR("%s %s %s test failed: result buffer is null!\n", name, tag, what);
            return false;
        }
        if (static_cast<int>(expected.size()) != size) {
            MNN_ERROR("%s %s %s test failed: reference holds %d values, but %d are required!\n", name, tag, what,
                      static_cast<int>(expected.size()), size);
            return false;
        }

        int mismatches = 0;
        for (int i = 0; i < size; ++i) {
            const auto diff = ::fabsf(got[i] - expected[i]);
            // Written as "within tolerance -> skip" so that a NaN difference counts as a mismatch.
            if (diff <= 0.0001) {
                continue;
            }
            ++mismatches;
            MNN_ERROR("%d: %s %s %s test failed, expected: %f, but got: %f!\n", numberByMismatch ? mismatches : i,
                      name, tag, what, expected[i], got[i]);
            if (failOnMismatch) {
                return false;
            }
        }
        return true;
    }
};

// Registers InterpGradTest under the name "grad/interp".
// NOTE(review): the MNNTestSuiteRegister macro is defined outside this file; presumably
// the test runner uses this name to look the case up -- confirm in MNNTestSuite.h.
MNNTestSuiteRegister(InterpGradTest, "grad/interp");
